@echo off
:: ---------------------------------------------------------------------------
:: SD3.5 Medium LoRA training + sampling smoke test.
:: Runs sd-scripts' sd3_train_network.py with the paths configured below.
:: setlocal keeps every variable set by this script (including the CUDA /
:: PyTorch settings) from leaking into the console session that launched it.
:: ---------------------------------------------------------------------------
setlocal
:: Switch the console to the UTF-8 code page so the Chinese messages display
:: correctly (this file must itself be saved as UTF-8 for that to work).
chcp 65001 >nul
echo ========================================
echo SD3.5 Medium 训练和采样测试脚本
echo ========================================

:: Environment variables.
:: PyTorch CUDA caching-allocator option: expandable segments reduce
:: fragmentation-related out-of-memory failures.
set "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
:: Forces synchronous CUDA kernel launches. Helpful for locating the exact
:: source of a CUDA error, but it slows training down -- drop it for real runs.
set "CUDA_LAUNCH_BLOCKING=1"

:: Base paths. The quoted set "VAR=value" form prevents stray trailing
:: whitespace from becoming part of a path.
:: NOTE(review): these are POSIX-style paths (/workspace, and .train/bin/python
:: is the Linux/macOS virtualenv layout; a Windows venv uses
:: .train\Scripts\python.exe). Confirm this batch file really targets Windows,
:: or port it to a shell script.
set "WORKSPACE=/workspace"
set "AI_DIR=%WORKSPACE%/AI"
set "PYTHON_PATH=%AI_DIR%/.train/bin/python"
set "SCRIPT_PATH=%AI_DIR%/ComfyUI/custom_nodes/comfyui_lora_train/sd_scripts/sd3_train_network.py"

:: Model paths: SD3.5 Medium checkpoint plus its three text encoders.
set "MODELS_DIR=%AI_DIR%/ComfyUI/models"
set "SD3_MODEL=%MODELS_DIR%/checkpoints/sd3.5_medium.safetensors"
set "CLIP_L_MODEL=%MODELS_DIR%/clip/clip_l.safetensors"
set "CLIP_G_MODEL=%MODELS_DIR%/clip/clip_g.safetensors"
set "T5XXL_MODEL=%MODELS_DIR%/text_encoders/t5xxl_fp8_e4m3fn.safetensors"

:: Dataset, output and sample-prompt paths.
set "TRAIN_DATA_DIR=%WORKSPACE%/training_data_all/girl"
set "OUTPUT_DIR=%WORKSPACE%/training_data_all/output"
set "SAMPLE_PROMPTS=%WORKSPACE%/training_data_all/prompts.txt"

:: Verify that every required input exists before launching the (slow)
:: training job. On the first missing item: report it, wait for a key press,
:: and exit with status 1.
echo 检查必要文件...
if not exist "%SD3_MODEL%" (
    echo 错误: SD3模型文件不存在: %SD3_MODEL%
    pause
    exit /b 1
)

if not exist "%CLIP_L_MODEL%" (
    echo 错误: CLIP-L模型文件不存在: %CLIP_L_MODEL%
    pause
    exit /b 1
)

if not exist "%CLIP_G_MODEL%" (
    echo 错误: CLIP-G模型文件不存在: %CLIP_G_MODEL%
    pause
    exit /b 1
)

if not exist "%T5XXL_MODEL%" (
    echo 错误: T5XXL模型文件不存在: %T5XXL_MODEL%
    pause
    exit /b 1
)

if not exist "%TRAIN_DATA_DIR%" (
    echo 错误: 训练数据目录不存在: %TRAIN_DATA_DIR%
    pause
    exit /b 1
)

if not exist "%SAMPLE_PROMPTS%" (
    echo 错误: 采样提示词文件不存在: %SAMPLE_PROMPTS%
    pause
    exit /b 1
)

:: Create the output directory if it is missing.
if not exist "%OUTPUT_DIR%" (
    echo 创建输出目录: %OUTPUT_DIR%
    mkdir "%OUTPUT_DIR%"
)
:: Re-check instead of trusting mkdir: if creation failed (permissions, bad
:: path), stop here rather than letting the trainer fail much later.
if not exist "%OUTPUT_DIR%" (
    echo 错误: 无法创建输出目录: %OUTPUT_DIR%
    pause
    exit /b 1
)

echo.
echo ========================================
echo 开始SD3.5 Medium训练和采样测试
echo ========================================
echo 模型: %SD3_MODEL%
echo 训练数据: %TRAIN_DATA_DIR%
echo 输出目录: %OUTPUT_DIR%
echo 采样提示词: %SAMPLE_PROMPTS%
echo.

:: Launch LoRA training, generating preview samples along the way
:: (--sample_at_first, then every 10 steps from the prompts file).
:: Comments cannot sit inside a ^-continued command, so notes go here:
::   * --random_crop is intentionally NOT passed: sd-scripts refuses to cache
::     latents (--cache_latents) when random_crop or color_aug is enabled and
::     aborts with an assertion error.
echo 执行训练命令...
"%PYTHON_PATH%" "%SCRIPT_PATH%" ^
    --pretrained_model_name_or_path="%SD3_MODEL%" ^
    --train_data_dir="%TRAIN_DATA_DIR%" ^
    --output_dir="%OUTPUT_DIR%" ^
    --output_name="sd3_test_lora" ^
    --clip_l="%CLIP_L_MODEL%" ^
    --clip_g="%CLIP_G_MODEL%" ^
    --t5xxl="%T5XXL_MODEL%" ^
    --network_dim=16 ^
    --network_alpha=8 ^
    --learning_rate=5e-5 ^
    --max_train_epochs=1 ^
    --train_batch_size=1 ^
    --resolution=512 ^
    --optimizer_type=AdamW ^
    --lr_scheduler=linear ^
    --mixed_precision=bf16 ^
    --save_every_n_steps=50 ^
    --save_model_as=safetensors ^
    --gradient_checkpointing ^
    --max_data_loader_n_workers=2 ^
    --persistent_data_loader_workers ^
    --seed=42 ^
    --max_grad_norm=1.0 ^
    --flip_aug ^
    --fp8_base ^
    --blocks_to_swap=8 ^
    --disable_mmap_load_safetensors ^
    --network_train_unet_only ^
    --cache_latents ^
    --cache_text_encoder_outputs ^
    --sample_every_n_steps=10 ^
    --sample_prompts="%SAMPLE_PROMPTS%" ^
    --sample_sampler=ddim ^
    --sample_at_first

:: Capture the trainer's exit code immediately, before any other command can
:: change ERRORLEVEL. Comparing against "0" also catches negative crash codes,
:: which "if errorlevel 1" would miss.
set "TRAIN_EXIT_CODE=%ERRORLEVEL%"
if not "%TRAIN_EXIT_CODE%"=="0" (
    echo.
    echo 错误: 训练失败，退出码: %TRAIN_EXIT_CODE%
    pause
    exit /b %TRAIN_EXIT_CODE%
)

echo.
echo ========================================
echo 训练完成！
echo 输出目录: %OUTPUT_DIR%
echo ========================================
pause